#include <algorithm>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "yolox_detector.h"

// Class-name lookup table with 80 entries (presumably the COCO detection
// classes, going by the identifier). DrawObjects() indexes it with a
// detection's `label`, so the order of the entries must not change.
static const std::vector<std::string> coco_names{
    // 0 - 4
    "person", "bicycle", "car", "motorcycle", "airplane",
    // 5 - 9
    "bus", "train", "truck", "boat", "traffic light",
    // 10 - 14
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    // 15 - 19
    "cat", "dog", "horse", "sheep", "cow",
    // 20 - 24
    "elephant", "bear", "zebra", "giraffe", "backpack",
    // 25 - 29
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    // 30 - 34
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    // 35 - 39
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    // 40 - 44
    "wine glass", "cup", "fork", "knife", "spoon",
    // 45 - 49
    "bowl", "banana", "apple", "sandwich", "orange",
    // 50 - 54
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    // 55 - 59
    "cake", "chair", "couch", "potted plant", "bed",
    // 60 - 64
    "dining table", "toilet", "tv", "laptop", "mouse",
    // 65 - 69
    "remote", "keyboard", "cell phone", "microwave", "oven",
    // 70 - 74
    "toaster", "sink", "refrigerator", "book", "clock",
    // 75 - 79
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
};

// Per-class drawing colours: 80 rows of three float components, each in the
// range [0, 1]. DrawObjects() picks the row with a detection's `label`, passes
// the three components to cv::Scalar in the order stored here and scales them
// by 255 before drawing.
// NOTE(review): the channel order (RGB vs. BGR) is not stated anywhere in this
// file - confirm against the image format if exact colours matter.
static const float color_list[80][3] = {
    {0.000, 0.447, 0.741}, {0.850, 0.325, 0.098}, {0.929, 0.694, 0.125},
    {0.494, 0.184, 0.556}, {0.466, 0.674, 0.188}, {0.301, 0.745, 0.933},
    {0.635, 0.078, 0.184}, {0.300, 0.300, 0.300}, {0.600, 0.600, 0.600},
    {1.000, 0.000, 0.000}, {1.000, 0.500, 0.000}, {0.749, 0.749, 0.000},
    {0.000, 1.000, 0.000}, {0.000, 0.000, 1.000}, {0.667, 0.000, 1.000},
    {0.333, 0.333, 0.000}, {0.333, 0.667, 0.000}, {0.333, 1.000, 0.000},
    {0.667, 0.333, 0.000}, {0.667, 0.667, 0.000}, {0.667, 1.000, 0.000},
    {1.000, 0.333, 0.000}, {1.000, 0.667, 0.000}, {1.000, 1.000, 0.000},
    {0.000, 0.333, 0.500}, {0.000, 0.667, 0.500}, {0.000, 1.000, 0.500},
    {0.333, 0.000, 0.500}, {0.333, 0.333, 0.500}, {0.333, 0.667, 0.500},
    {0.333, 1.000, 0.500}, {0.667, 0.000, 0.500}, {0.667, 0.333, 0.500},
    {0.667, 0.667, 0.500}, {0.667, 1.000, 0.500}, {1.000, 0.000, 0.500},
    {1.000, 0.333, 0.500}, {1.000, 0.667, 0.500}, {1.000, 1.000, 0.500},
    {0.000, 0.333, 1.000}, {0.000, 0.667, 1.000}, {0.000, 1.000, 1.000},
    {0.333, 0.000, 1.000}, {0.333, 0.333, 1.000}, {0.333, 0.667, 1.000},
    {0.333, 1.000, 1.000}, {0.667, 0.000, 1.000}, {0.667, 0.333, 1.000},
    {0.667, 0.667, 1.000}, {0.667, 1.000, 1.000}, {1.000, 0.000, 1.000},
    {1.000, 0.333, 1.000}, {1.000, 0.667, 1.000}, {0.333, 0.000, 0.000},
    {0.500, 0.000, 0.000}, {0.667, 0.000, 0.000}, {0.833, 0.000, 0.000},
    {1.000, 0.000, 0.000}, {0.000, 0.167, 0.000}, {0.000, 0.333, 0.000},
    {0.000, 0.500, 0.000}, {0.000, 0.667, 0.000}, {0.000, 0.833, 0.000},
    {0.000, 1.000, 0.000}, {0.000, 0.000, 0.167}, {0.000, 0.000, 0.333},
    {0.000, 0.000, 0.500}, {0.000, 0.000, 0.667}, {0.000, 0.000, 0.833},
    {0.000, 0.000, 1.000}, {0.000, 0.000, 0.000}, {0.143, 0.143, 0.143},
    {0.286, 0.286, 0.286}, {0.429, 0.429, 0.429}, {0.571, 0.571, 0.571},
    {0.714, 0.714, 0.714}, {0.857, 0.857, 0.857}, {0.000, 0.447, 0.741},
    {0.314, 0.717, 0.741}, {0.50, 0.5, 0}};

/// Draws every detection onto `image` and shows the annotated frame.
///
/// For each object a 2 px bounding box is drawn in the colour that
/// `color_list` assigns to its label, together with a filled white tag that
/// holds "<class name>, <confidence * 100 with two decimals>%" in black text.
/// The tag sits on the top edge of the box; when the box touches the top of the
/// frame the tag is shifted down so that the text stays visible.
///
/// After all objects are drawn the frame is shown in the window
/// "YOLOX Object detection" and the function waits up to 10 ms for a key press.
///
/// @param image    Frame to annotate. It is taken by const reference, but the
///                 OpenCV drawing calls write into its pixel data, so the
///                 caller sees the annotations afterwards.
/// @param objects  Detections to draw; `label`, `confidence` and `box` are
///                 read from each element.
/// @throws std::out_of_range if an object's label is not a valid index into
///         `coco_names`. The check happens before the colour table is touched.
void DrawObjects(const cv::Mat &image, const std::vector<Object> &objects) {
  constexpr std::size_t kColorCount =
      sizeof(color_list) / sizeof(color_list[0]);
  const int font_face = cv::FONT_HERSHEY_SIMPLEX;
  const double font_scale = 0.4;
  const int font_thickness = 1;

  for (const auto &obj : objects) {
    // Resolve the name first: at() rejects an invalid label before that label
    // is used to index the raw colour array, which has no bounds checking.
    const std::string &class_name = coco_names.at(obj.label);
    // The modulo keeps the index inside the colour table even if the name
    // table were ever to grow beyond it.
    const float *rgb =
        color_list[static_cast<std::size_t>(obj.label) % kColorCount];
    const cv::Scalar box_color(rgb[0], rgb[1], rgb[2]);

    std::ostringstream label_stream;
    label_stream << class_name << ", " << std::setprecision(2) << std::fixed
                 << obj.confidence * 100 << "%";
    const std::string text = label_stream.str();

    const cv::Size text_size = cv::getTextSize(text, font_face, font_scale,
                                               font_thickness, nullptr);
    const int x = obj.box.x;
    // `y` is the text baseline. Clamp it so that the tag, which extends
    // text_size.height pixels upwards, does not leave the frame at the top.
    const int y = std::max<int>(obj.box.y, text_size.height);

    // Colour components are stored in [0, 1]; scale them to the 0-255 range.
    cv::rectangle(image, obj.box, box_color * 255, 2);
    // Thickness -1 draws a filled rectangle: the white background of the tag.
    cv::rectangle(image,
                  cv::Rect(x, y - text_size.height, text_size.width + 5,
                           text_size.height + 5),
                  cv::Scalar(255, 255, 255), -1);

    cv::putText(image, text, cv::Point(x, y), font_face, font_scale,
                cv::Scalar(0, 0, 0), font_thickness, cv::LINE_AA);
  }

  cv::imshow("YOLOX Object detection", image);
  cv::waitKey(10);
}

/// Runs the YOLOX detector over a video file.
///
/// Usage: <program> </path/to/onnx/model> </path/to/video>
///
/// Each frame of the input video is passed to the detector, annotated and
/// displayed by DrawObjects(), and appended to "result.mp4" (relative to the
/// current working directory). The output uses the frame size of the input and
/// its frame rate, or 30 fps when the input does not report a usable rate.
///
/// @return 0 when the whole video was processed; -1 when arguments are
///         missing, the detector fails to initialise, the input video cannot
///         be opened or the output file cannot be created.
int main(int argc, char **argv) {
  if (argc < 3) {
    std::cerr << "Usage: " << argv[0]
              << " </path/to/onnx/model> </path/to/video>\n";
    return -1;
  }

  const std::string model_path(argv[1]);
  const std::string video_path(argv[2]);

  std::unique_ptr<YoloxDetector> yolox_detector(new YoloxDetector(model_path));
  if (!yolox_detector->Init()) {
    std::cerr << "Failed to initialize yolox detector!\n";
    return -1;
  }

  // Input video.
  cv::VideoCapture video_capture;
  video_capture.open(video_path);
  if (!video_capture.isOpened()) {
    std::cerr << "Can't open the video: " << video_path << '\n';
    return -1;
  }

  // Output video. 'mp4v' is the MPEG-4 tag that the MP4 container accepts;
  // 'DIVX' is an AVI tag that backends either replace with 'mp4v' after a
  // warning or reject for ".mp4" files.
  const int codec = cv::VideoWriter::fourcc('m', 'p', '4', 'v');

  // get() yields 0 for properties the capture backend cannot report, and a
  // writer cannot be opened with such a rate. The negated comparison also
  // catches NaN.
  const double kFallbackFps = 30.0;
  double fps = video_capture.get(cv::CAP_PROP_FPS);
  if (!(fps > 0.0)) {
    std::cerr << "Input video reports no valid frame rate, using "
              << kFallbackFps << " fps for the output.\n";
    fps = kFallbackFps;
  }

  // The properties come back as double; the frame size is integral.
  const int width =
      static_cast<int>(video_capture.get(cv::CAP_PROP_FRAME_WIDTH));
  const int height =
      static_cast<int>(video_capture.get(cv::CAP_PROP_FRAME_HEIGHT));

  cv::VideoWriter video_writer;
  video_writer.open("result.mp4", codec, fps, cv::Size(width, height));
  if (!video_writer.isOpened()) {
    std::cerr << "Can't create output video file!\n";
    return -1;
  }

  // Process frames until the capture stops delivering them.
  cv::Mat input_image;
  while (video_capture.read(input_image) && !input_image.empty()) {
    std::vector<Object> objects;
    yolox_detector->Detect(input_image, &objects);
    // Draws into input_image, so the written frame carries the annotations.
    DrawObjects(input_image, objects);
    video_writer.write(input_image);
  }

  video_capture.release();
  video_writer.release();

  return 0;
}